package ollamarunner

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"hash/maphash"
	"image"
	"log"
	"log/slog"
	"net"
	"net/http"
	"os"
	"reflect"
	"regexp"
	"runtime"
	"runtime/debug"
	"strconv"
	"strings"
	"sync"
	"time"
	"unicode/utf8"

	"golang.org/x/image/bmp"
	"golang.org/x/sync/semaphore"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
	"github.com/ollama/ollama/runner/common"
	"github.com/ollama/ollama/sample"

	_ "github.com/ollama/ollama/model/models"
)

// Sequence holds the runner-side state for one request: the inputs that
// still need to be evaluated, the cache slot it occupies, and the channels
// used to hand results back to the HTTP handler that created it.
type Sequence struct {
	// ctxs are used for allocating tensors that last the lifetime of the sequence, such as
	// multimodal embeddings
	ctxs []ml.Context

	// mmStore holds multimodal embeddings to manage memory and enable splitting across batches
	mmStore multimodalStore

	// batch index: the row of the current batch's outputs (and therefore of
	// the computed logits) that belongs to this sequence; set by forwardBatch
	iBatch int

	// prompt inputs left to evaluate
	inputs []*input.Input

	// inputs that have been added to a batch but not yet submitted to Forward;
	// computeBatch moves them into cache.Inputs
	pendingInputs []*input.Input

	// tokens that have been generated but not returned yet (e.g. for stop sequences)
	pendingResponses []string

	// input cache being used by this sequence
	cache *InputCacheSlot

	// channel to send responses over; closed by removeSequence
	responses chan string

	// channel to stop decoding (such as if the remote connection is closed)
	quit chan bool

	// number of tokens to predict; values <= 0 mean no limit
	numPredict int

	// sampler with transforms to run on generated logits
	sampler sample.Sampler

	// channel to send back the embedding if embedding only; closed by removeSequence
	embedding chan []float32

	// stop sequences
	stop []string

	// number of inputs to keep at the beginning when shifting context window
	numKeep int32

	// true if an embedding is to be returned instead of text generation
	embeddingOnly bool

	// why generation ended; set by removeSequence before responses is closed
	doneReason llm.DoneReason

	// Metrics
	// startProcessingTime: when NewSequence started processing the prompt.
	// startGenerationTime: when the first token was sampled.
	// numPredicted: tokens scheduled for sampling so far.
	// numPromptInputs: prompt length after any truncation.
	startProcessingTime time.Time
	startGenerationTime time.Time
	numPredicted        int
	numPromptInputs     int
}

// NewSequenceParams carries the per-request settings consumed by NewSequence.
type NewSequenceParams struct {
	// maximum number of tokens to generate; values <= 0 mean no limit
	numPredict int
	// stop sequences that end generation when produced
	stop []string
	// prompt inputs preserved when the context is shifted or the prompt
	// truncated; a negative value keeps the whole prompt (then clamped to numCtx-1)
	numKeep int32
	sampler sample.Sampler
	// copied to Sequence.embeddingOnly
	embedding bool
}

// NewSequence tokenizes prompt (splicing in images for multimodal models)
// and returns a Sequence ready to be placed in s.seqs. It blocks until the
// model has been loaded (s.ready).
//
// If the prompt is longer than the context (s.cache.numCtx), the first
// numKeep inputs are kept and enough inputs after them are dropped to fit.
// The dropped range is extended so it never splits a group of inputs that
// must stay in the same batch (input.SameBatch). An error is returned if a
// SameBatch group starts inside the kept prefix or if truncation would
// remove the entire prompt.
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
	s.ready.Wait()

	startTime := time.Now()

	inputs, ctxs, mmStore, err := s.inputs(prompt, images)
	if err != nil {
		return nil, fmt.Errorf("failed to process inputs: %w", err)
	} else if len(inputs) == 0 {
		return nil, errors.New("no input provided")
	}

	// A negative numKeep means "keep the whole prompt".
	if params.numKeep < 0 {
		params.numKeep = int32(len(inputs))
	}

	// Ensure that at least 1 input can be discarded during shift
	params.numKeep = min(params.numKeep, s.cache.numCtx-1)

	if int32(len(inputs)) > s.cache.numCtx {
		// Drop inputs [numKeep, promptStart) so that exactly numCtx remain.
		discard := int32(len(inputs)) - s.cache.numCtx
		promptStart := params.numKeep + discard

		// If we need to truncate in the middle of a unbreakable batch, remove the entire batch
		// sameBatch counts how many following inputs still belong to the
		// current SameBatch group; while it is non-zero, promptStart is pushed
		// past the group.
		sameBatch := 0
		for i, inp := range inputs {
			if sameBatch > 0 {
				sameBatch--

				if promptStart == int32(i) {
					promptStart++
				}
			} else if promptStart == int32(i) {
				break
			}

			if inp.SameBatch != 0 {
				if int32(i) < params.numKeep {
					return nil, fmt.Errorf("SameBatch may not be specified within numKeep (index: %v numKeep: %v SameBatch: %v)", i, params.numKeep, inp.SameBatch)
				}

				sameBatch = inp.SameBatch
			}
		}

		if promptStart >= int32(len(inputs)) {
			return nil, errors.New("entire prompt removed by truncation")
		}

		// newInputs shares inputs' backing array; since promptStart >= numKeep,
		// append copies the tail toward the front, which is safe.
		newInputs := inputs[:params.numKeep]
		newInputs = append(newInputs, inputs[promptStart:]...)

		slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
		inputs = newInputs
	}

	// TODO(jessegross): Ingest cached history for grammar

	return &Sequence{
		ctxs:                ctxs,
		mmStore:             mmStore,
		inputs:              inputs,
		numPromptInputs:     len(inputs),
		startProcessingTime: startTime,
		numPredict:          params.numPredict,
		pendingResponses:    make([]string, 0),
		responses:           make(chan string, 100),
		quit:                make(chan bool, 1),
		embedding:           make(chan []float32, 1),
		sampler:             params.sampler,
		embeddingOnly:       params.embedding,
		stop:                params.stop,
		numKeep:             params.numKeep,
	}, nil
}

// imageTagPattern matches the [img-<n>] tags that mark where the request's
// image with ID n belongs in the prompt. It is compiled once at package
// initialization instead of on every call to inputs.
var imageTagPattern = regexp.MustCompile(`\[img-(\d+)\]`)

// multimodalHashMu serializes use of Server.multimodalHash in inputs.
// inputs is reached from concurrent HTTP handlers (completion -> NewSequence)
// before s.mu is taken, and a maphash.Hash is not safe for concurrent use.
var multimodalHashMu sync.Mutex

// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images.
//
// Tags are only recognized for models implementing model.MultimodalProcessor.
// For those models, each tag is replaced by the embedding of the image whose
// ID is n, and a tag naming an unknown image ID is an error. Each image is
// encoded in its own backend context (closed by a finalizer), and the
// contexts are returned so the sequence can keep them alive. Other models
// have the whole prompt tokenized as text.
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input, []ml.Context, multimodalStore, error) {
	var inputs []*input.Input
	var ctxs []ml.Context
	var mmStore multimodalStore

	var parts []string
	var matches [][]string

	multimodalProcessor, visionModel := s.model.(model.MultimodalProcessor)

	if visionModel {
		parts = imageTagPattern.Split(prompt, -1)
		matches = imageTagPattern.FindAllStringSubmatch(prompt, -1)
		mmStore = newMultimodalStore()
	} else {
		parts = []string{prompt}
	}

	postTokenize := false
	for i, part := range parts {
		// text - tokenize (special tokens only for the first part)
		tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
		if err != nil {
			return nil, nil, nil, err
		}

		for _, t := range tokens {
			inputs = append(inputs, &input.Input{Token: t})
		}

		// image - decode and store
		if i < len(matches) {
			// The regexp only matches digits; an out-of-range value simply
			// fails the ID lookup below.
			n, _ := strconv.Atoi(matches[i][1])

			imageIndex := -1
			for j := range images {
				if images[j].ID == n {
					imageIndex = j
					break
				}
			}

			if imageIndex < 0 {
				return nil, nil, nil, fmt.Errorf("invalid image index: %d", n)
			}

			ctx := s.model.Backend().NewContext()
			runtime.SetFinalizer(ctx, func(c ml.Context) { c.Close() })
			ctxs = append(ctxs, ctx)
			imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
			if err != nil {
				return nil, nil, nil, err
			}

			// The hash lets the cache recognize identical images across requests,
			// so it must come from the shared (same-seed) hasher, and that hasher
			// must not be used by two requests at once.
			multimodalHashMu.Lock()
			s.multimodalHash.Reset()
			_, _ = s.multimodalHash.Write(images[imageIndex].Data) // maphash.Hash.Write never fails
			imageHash := s.multimodalHash.Sum64()
			multimodalHashMu.Unlock()

			mmStore.addMultimodal(imageEmbeddings)

			inputs = append(inputs, &input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
			postTokenize = true
		}
	}

	if visionModel && postTokenize {
		var err error
		inputs, err = multimodalProcessor.PostTokenize(inputs)
		if err != nil {
			return nil, nil, nil, err
		}
	}

	return inputs, ctxs, mmStore, nil
}

// batchState tracks one batch as it is handed from forwardBatch (which builds
// the graph) to computeBatch (which runs it and samples the results).
type batchState struct {
	// id provides a counter for trace logging batches
	id int

	// ctx holds the backend context used for this batch; nil means there is
	// nothing to compute
	ctx ml.Context

	// modelOutput holds the outputs from this batch
	modelOutput ml.Tensor

	// batchInputs holds the input token pointers which may start as
	// placeholders later filled in before calling ctx.Compute
	batchInputs []*input.Input

	// batch contains the inputs for a model forward pass
	batch input.Batch

	// full set of seqs at the time this batch was initiated
	seqs []*Sequence

	// Signaled when this batch's inputs are ready and compute can proceed
	inputsReadyCh chan struct{}

	// Signaled when Compute is about to begin on this batch, and
	// seqs have been updated to prepare for the next batch
	computeStartedCh chan struct{}

	// Signaled when this batch's outputs are complete and the next batch can proceed
	outputsReadyCh chan struct{}
}

// Server is the runner's HTTP server together with the loaded model and the
// state of the sequences currently being decoded.
type Server struct {
	// modelPath is the location of the model to be loaded
	modelPath string

	// loadMu prevents more than one load attempt from occurring at a time
	loadMu sync.Mutex

	// lastLoad is the load request from the previous load attempt. Used to
	// detect if we can reuse an existing memory allocation.
	lastLoad llm.LoadRequest

	// is the server ready to process requests?
	// Released (Done) once the model weights are loaded; NewSequence and run
	// wait on it before touching the model.
	ready sync.WaitGroup

	// loaded model
	model model.Model

	// status for external health reporting - loading, ready to serve, etc.
	status llm.ServerStatus

	// current progress on loading the model
	progress float32

	// number of simultaneous requests to handle
	parallel int

	// maximum number of elements in a batch (per sequence)
	// TODO (jmorganca): make this n_batch
	batchSize int

	// Used to signal a hard failure during async processing which will panic the runner
	hardErrCh chan error

	// Simple counter used only for trace logging batches
	batchID int

	// protects access to everything below this line
	// this is context state needed for decoding
	mu sync.Mutex

	// indicates that data is ready for processing
	cond *sync.Cond

	// the list of simultaneous sequences being evaluated
	seqs []*Sequence

	// seqs can have a maximum of parallel entries, which
	// is enforced by seqsSem
	seqsSem *semaphore.Weighted

	// KV cache
	cache *InputCache

	// next sequence for prompt processing to avoid starvation
	nextSeq int

	// multimodalHash generates hashes for comparing equality
	// of non-text data
	multimodalHash maphash.Hash
}

// allNil reports whether every slot in s.seqs is empty. Its caller in
// forwardBatch holds s.mu while calling it.
func (s *Server) allNil() bool {
	empty := true
	for idx := 0; idx < len(s.seqs); idx++ {
		if s.seqs[idx] != nil {
			empty = false
			break
		}
	}
	return empty
}

// flushPending joins and clears seq.pendingResponses and sends the result on
// seq.responses.
//
// Partial or invalid UTF-8 can still be present here, for example when the
// sequence ends (such as at the generation limit) in the middle of a
// multi-byte character, or when invalid bytes appear mid-string. To make sure
// invalid Unicode is never emitted, only the longest valid UTF-8 prefix of the
// text is sent.
//
// It returns false if seq.quit fired before the text could be delivered, and
// true otherwise, including when there was nothing to send.
func flushPending(seq *Sequence) bool {
	text := strings.Join(seq.pendingResponses, "")
	seq.pendingResponses = []string{}

	// Cut at the first byte that does not begin a complete, valid encoding.
	// That offset is exactly where the longest valid prefix ends.
	for pos := 0; pos < len(text); {
		r, width := utf8.DecodeRuneInString(text[pos:])
		if r == utf8.RuneError && width == 1 {
			text = text[:pos]
			break
		}
		pos += width
	}

	if text == "" {
		return true
	}

	select {
	case <-seq.quit:
		return false
	case seq.responses <- text:
		return true
	}
}

// removeSequence retires the sequence in slot seqIndex. It flushes any
// buffered text, records why it finished, closes its output channels, and
// frees both its cache slot and its slot in s.seqs.
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
	finished := s.seqs[seqIndex]

	// Best effort: if the client already went away there is nobody to receive it.
	_ = flushPending(finished)

	// doneReason must be written before responses is closed. The completion
	// handler reads it after observing the close.
	finished.doneReason = reason
	close(finished.responses)
	close(finished.embedding)

	// Hand the cache slot and the parallelism slot back for reuse.
	finished.cache.InUse = false
	s.seqs[seqIndex] = nil
	s.seqsSem.Release(1)
}

// batchState values are produced by forwardBatch and consumed by computeBatch; run (below) drives that hand-off.

// run is the scheduling loop. Once the model is ready, it repeatedly builds
// the next batch with forwardBatch and starts computeBatch for it on a new
// goroutine. It returns when ctx is cancelled and panics on any error from
// forwardBatch or any hard error reported on s.hardErrCh.
func (s *Server) run(ctx context.Context) {
	s.ready.Wait()

	var current batchState
	for {
		select {
		case <-ctx.Done():
			return
		case hardErr := <-s.hardErrCh:
			panic(hardErr)
		default:
		}

		next, err := s.forwardBatch(current)
		if err != nil {
			panic(err)
		}
		current = next
		go s.computeBatch(current)
	}
}

// forwardBatch assembles the next batch from the active sequences and builds
// (but does not compute) its forward graph. run passes the result to
// computeBatch.
//
// If pendingBatch (the batch most recently handed to computeBatch) has a
// backend context, forwardBatch first waits until computeBatch has started
// computing it. At that point each sequence's inputs already hold the
// placeholder tokens that pendingBatch will fill in. The new batch's
// inputsReadyCh is chained to pendingBatch's outputsReadyCh so batches are
// computed strictly in order.
//
// A returned batch with a nil ctx has nothing to compute. Its outputsReadyCh
// forwards the completion signal it was chained to, so the next non-empty
// batch still waits for any earlier computation instead of running
// concurrently with it.
func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, err error) {
	// If we have a pending batch still processing, wait until Compute has started
	// before setting up the next batch so the seqs inputs are ready to receive their
	// token values and we get the correct input pointers for the batchInputs
	switch {
	case pendingBatch.ctx != nil:
		slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch waiting for compute to start", "pendingBatch.id", pendingBatch.id)
		<-pendingBatch.computeStartedCh
		slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch compute started, setting up next batch", "pendingBatch.id", pendingBatch.id, "id", s.batchID)
		nextBatch.inputsReadyCh = pendingBatch.outputsReadyCh // Chain the outputs from the pending batch to the next inputs batch
	case pendingBatch.outputsReadyCh != nil:
		// The previous call produced an empty batch, but it was chained to an
		// earlier batch whose computation may still be running. Keep waiting on it.
		slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch chaining to batch before idle batch", "batchID", s.batchID)
		nextBatch.inputsReadyCh = pendingBatch.outputsReadyCh
	default:
		slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no pending batch detected", "batchID", s.batchID)
		// No pendingBatch, so the inputs will be ready in the seqs immediately
		nextBatch.inputsReadyCh = make(chan struct{}, 1)
		nextBatch.inputsReadyCh <- struct{}{}
	}

	s.mu.Lock()
	for s.allNil() {
		s.cond.Wait() // Wait until an item is added
	}
	defer s.mu.Unlock()

	nextBatch.ctx = s.model.Backend().NewContext()
	defer func() {
		if err != nil {
			nextBatch.ctx.Close()
			nextBatch.ctx = nil
		}
	}()
	nextBatch.id = s.batchID
	nextBatch.seqs = append([]*Sequence{}, s.seqs...)
	nextBatch.computeStartedCh = make(chan struct{}, 1)
	nextBatch.outputsReadyCh = make(chan struct{}, 1)

	// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
	var batchInputs []*input.Input
	var batch input.Batch

	resumeSeq := -1
	seqIdx := s.nextSeq - 1
seqLoop:
	for range s.seqs {
		seqIdx = (seqIdx + 1) % len(s.seqs)
		seq := s.seqs[seqIdx]
		if seq == nil {
			continue
		}

		// if past the num predict limit
		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
			s.removeSequence(seqIdx, llm.DoneReasonLength)
			nextBatch.seqs[seqIdx] = nil
			continue
		}

		if !s.cache.enabled {
			seq.inputs = append(seq.cache.Inputs, seq.inputs...)
			seq.cache.Inputs = []*input.Input{}
		}

		batchSize := s.batchSize

		for i, inp := range seq.inputs {
			// If we are required to put following inputs into a single batch then extend the
			// batch size. Since we are only extending the size the minimum amount possible, this
			// will cause a break if we have existing inputs.
			minBatch := 1 + inp.SameBatch
			if minBatch > batchSize {
				batchSize = minBatch
			}

			// Stop if the required batch would put us over the total batch size (including tokens
			// added by other sequences). If we haven't been able to add anything yet then pick up
			// here again for the next batch to avoid starvation, though we can opportunistically
			// check if other sequences can still squeeze something in.
			if len(batchInputs)+minBatch > batchSize {
				if len(seq.pendingInputs) == 0 && resumeSeq == -1 {
					resumeSeq = seqIdx
				}
				break
			}

			// If the sum of our working set (already processed tokens, tokens we added to this
			// batch, required following tokens) exceeds the context size, then trigger a shift
			// now so we don't have to do one later when we can't break the batch.
			if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
				if len(seq.pendingInputs) != 0 {
					break
				}

				err = s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
				if err != nil {
					var reprocess *ErrReprocessInputs
					if !errors.As(err, &reprocess) {
						return
					}
					// Prepend these inputs to the sequence's inputs queue for reprocessing
					seq.inputs = append(reprocess.Inputs, seq.inputs...)
					// Skip this sequence for this batch but continue with the rest.
					// The labeled continue matters: a plain continue would only
					// advance over the stale inputs slice of this same sequence.
					// Nothing from this sequence has been added yet (pendingInputs
					// is empty here), so skipping the trailing seq.inputs re-slice
					// is a no-op.
					nextBatch.seqs[seqIdx] = nil
					err = nil
					continue seqLoop
				}
			}

			batchInputs = append(batchInputs, seq.inputs[i])
			if inp.Multimodal != nil {
				var mm []input.Multimodal
				mm, err = seq.mmStore.getMultimodal(s.model.Backend(), nextBatch.ctx, inp.Multimodal, false)
				if err != nil {
					return
				}
				batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
			}

			batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
			batch.Sequences = append(batch.Sequences, seq.cache.Id)

			seq.iBatch = len(batch.Outputs)
			if i+1 == len(seq.inputs) {
				batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
			}
			slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
			seq.pendingInputs = append(seq.pendingInputs, inp)
		}

		seq.inputs = seq.inputs[len(seq.pendingInputs):]
	}

	if resumeSeq != -1 {
		s.nextSeq = resumeSeq
	} else {
		s.nextSeq = seqIdx + 1
	}

	if len(batchInputs) == 0 {
		slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no batchInputs, going idle", "batchID", s.batchID)
		nextBatch.ctx.Close()
		nextBatch.ctx = nil
		// computeBatch ignores batches without a ctx, so nothing would consume
		// inputsReadyCh. Pass it on as this batch's output signal so the next
		// batch still waits for whatever computation it stands for.
		nextBatch.outputsReadyCh = nextBatch.inputsReadyCh
		return
	}
	s.batchID++

	// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
	batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
	nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
	if err != nil {
		err = fmt.Errorf("failed to build graph: %w", err)
		return
	}
	nextBatch.batchInputs = batchInputs
	nextBatch.batch = batch

	return
}

// computeBatch runs the forward pass for activeBatch and samples the next
// token for every sequence whose inputs were fully consumed by this batch.
// run starts it on its own goroutine for each batch returned by forwardBatch.
//
// Ordering with neighbouring batches is enforced with channels:
//   - it waits on inputsReadyCh so placeholder tokens in this batch have
//     been filled in by the previous batch;
//   - it signals computeStartedCh once the sequences have been prepared for
//     the next batch, which unblocks forwardBatch;
//   - it always signals outputsReadyCh when it returns.
//
// Unrecoverable failures are reported on s.hardErrCh, which makes run panic.
func (s *Server) computeBatch(activeBatch batchState) {
	if activeBatch.ctx == nil {
		// Nothing to compute
		return
	}
	defer activeBatch.ctx.Close()

	// Wait until inputs are ready
	slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: waiting for inputs to be ready", "batchID", activeBatch.id)
	<-activeBatch.inputsReadyCh
	slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: inputs are ready", "batchID", activeBatch.id)

	// Once we complete, signal the next batch of inputs are ready
	// This will unblock the next computeBatch, or forwardBatch if new seqs come in
	defer func() {
		slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: outputs are ready", "batchID", activeBatch.id)
		activeBatch.outputsReadyCh <- struct{}{}
	}()

	s.mu.Lock()

	// Gather the actual input token values now that they're ready
	batchInputs := make([]int32, len(activeBatch.batchInputs))
	for i := range batchInputs {
		batchInputs[i] = activeBatch.batchInputs[i].Token
	}

	// Now we run part of the decoding algorithm to adjust the seq.inputs with placeholder tokens
	// so that forwardBatch can build a batchInputs set which will eventually contain the actual
	// decoded tokens.
	nextBatchTokens := make([]*input.Input, len(s.seqs))
	iBatches := make([]int, len(s.seqs)) // Record the iBatch values before releasing the lock
	for i, seq := range s.seqs {
		iBatches[i] = -1
		if seq == nil {
			continue
		}
		// Skip over any newly added or skipped sequences
		if activeBatch.seqs[i] == nil {
			continue
		}

		// Detect if the sequence we're processing has already been completed and replaced
		// with a new sequence
		if seq != activeBatch.seqs[i] {
			slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
			continue
		}

		// Pending inputs will actually be in the cache after we call Compute.
		// However, we have already resolved any placeholder tokens.
		//
		// It's possible for incoming sequences to look at the values that we've
		// added to the cache here and start relying on them before we've done
		// the computation. This is OK as long as we ensure that this batch's
		// computation happens before any future batch's and we never fail
		// (unless we take down the whole runner).
		if len(seq.pendingInputs) > 0 {
			seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
			seq.pendingInputs = []*input.Input{}
		}

		// don't sample prompt processing
		if len(seq.inputs) != 0 {
			if !s.cache.enabled {
				s.mu.Unlock()
				s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch")
				// Compute never starts for this batch, but forwardBatch may
				// already be blocked waiting for it. Signal anyway so it can
				// return and run gets a chance to observe the hard error.
				activeBatch.computeStartedCh <- struct{}{}
				return
			}
			continue
		}

		seq.numPredicted++
		nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
		seq.inputs = []*input.Input{nextToken}
		nextBatchTokens[i] = nextToken
		iBatches[i] = seq.iBatch
	}

	// At this point the seqs are ready for forwardBatch to move forward so unblock
	s.mu.Unlock()

	activeBatch.batch.Inputs.SetValueFromIntSlice(batchInputs)
	activeBatch.ctx.ComputeWithNotify(
		func() {
			slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: signaling computeStartedCh", "batchID", activeBatch.id)
			activeBatch.computeStartedCh <- struct{}{}
		},
		activeBatch.modelOutput)
	logits := activeBatch.modelOutput.Floats()

	slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: logits ready", "batchID", activeBatch.id)

	s.mu.Lock()
	defer s.mu.Unlock()

	slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: decoding", "batchID", activeBatch.id)
	for i, seq := range s.seqs {
		if seq == nil || nextBatchTokens[i] == nil {
			continue
		}

		// While the lock was released for Compute, forwardBatch may have removed
		// this sequence (e.g. numPredict reached) and completion may have put a
		// new one into the slot. These logits belong to the old sequence.
		if seq != activeBatch.seqs[i] {
			slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: sequence replaced during compute, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
			continue
		}

		if seq.numPredicted == 1 {
			seq.startGenerationTime = time.Now()
		}

		// if done processing the prompt, generate an embedding and return
		if seq.embeddingOnly {
			// TODO(jessegross): Embedding support
			slog.Warn("generation of embedding outputs not yet supported", "id", activeBatch.id, "seqIdx", i)
			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		// sample a token
		vocabSize := len(logits) / len(activeBatch.batch.Outputs)
		slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(logits), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
		token, err := seq.sampler.Sample(logits[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
		if err != nil {
			s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
			return
		}

		nextBatchTokens[i].Token = token

		// if it's an end of sequence token, break
		if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
			// TODO (jmorganca): we should send this back
			// as it's important for the /api/generate context
			// seq.responses <- piece
			slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: EOS", "batchID", activeBatch.id, "seqIdx", i)
			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
		if err != nil {
			s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err)
			return
		}

		seq.pendingResponses = append(seq.pendingResponses, piece)
		sequence := strings.Join(seq.pendingResponses, "")

		if ok, stop := common.FindStop(sequence, seq.stop); ok {
			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)

			var tokenTruncated bool
			origLen := len(seq.pendingResponses)
			seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop)
			newLen := len(seq.pendingResponses)

			// Update the cache based on the tokens that will be returned:
			// - We have 1 token more than is currently in the cache because
			// the last one generated wasn't submitted to Decode
			// - Remove any stop sequences that we stripped out
			// - If truncateStop removed a portion of a token, drop that
			// - As defense-in-depth, if truncatedToken didn't find a stop token
			// remove the extra one that we added to the cache len
			tokenLen := len(seq.cache.Inputs) + 1
			tokenLen -= origLen - newLen
			if tokenTruncated || origLen == newLen {
				tokenLen--
			}

			seq.cache.Inputs = seq.cache.Inputs[:tokenLen]

			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		if common.ContainsStopSuffix(sequence, seq.stop) {
			continue
		}

		if common.IncompleteUnicode(sequence) {
			continue
		}

		if !flushPending(seq) {
			s.removeSequence(i, llm.DoneReasonConnectionClosed)
		}
	}
}

// completion handles POST /completion. It decodes an llm.CompletionRequest,
// queues a new sequence for it, and streams the generated text back as
// JSON-encoded llm.CompletionResponse values. The final response has Done set
// and carries the timing and token-count statistics.
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
	var req llm.CompletionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	if req.Options == nil {
		opts := api.DefaultOptions()
		req.Options = &opts
	}

	// Set the headers to indicate streaming
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Transfer-Encoding", "chunked")

	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
		return
	}

	var grammar *sample.GrammarSampler
	var err error
	if req.Grammar != "" {
		grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
		if err != nil {
			http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
			return
		}
		// Must not run while the runner may still sample with this grammar; see abort below.
		defer grammar.Free()
	}

	sampler := sample.NewSampler(
		req.Options.Temperature,
		req.Options.TopK,
		req.Options.TopP,
		req.Options.MinP,
		req.Options.Seed,
		grammar,
	)

	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
		numPredict: req.Options.NumPredict,
		stop:       req.Options.Stop,
		numKeep:    int32(req.Options.NumKeep),
		sampler:    sampler,
		embedding:  false,
	})
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
		return
	}

	// Ensure there is a place to put the sequence, released when removed from s.seqs
	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
		if errors.Is(err, context.Canceled) {
			slog.Info("aborting completion request due to client closing the connection")
		} else {
			http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
		}
		return
	}

	s.mu.Lock()
	found := false
	for i, sq := range s.seqs {
		if sq == nil {
			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
			if err != nil {
				s.mu.Unlock()
				s.seqsSem.Release(1)
				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
				return
			}

			s.seqs[i] = seq
			s.cond.Signal()
			found = true
			break
		}
	}
	s.mu.Unlock()

	if !found {
		s.seqsSem.Release(1)
		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
		return
	}

	// abort asks the runner to stop this sequence after the client has gone
	// away or a response could not be written. computeBatch keeps calling
	// seq.sampler.Sample until it notices seq.quit and removes the sequence,
	// which closes seq.responses. When a grammar is in use, wait for that close
	// so the deferred grammar.Free cannot release memory the runner is still
	// using.
	abort := func() {
		close(seq.quit)
		if grammar != nil {
			for range seq.responses {
				// discard output produced after the request was abandoned
			}
		}
	}

	for {
		select {
		case <-r.Context().Done():
			abort()
			return
		case content, ok := <-seq.responses:
			if ok {
				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
					Content: content,
				}); err != nil {
					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
					abort()
					return
				}

				flusher.Flush()
			} else {
				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
					Done:               true,
					DoneReason:         seq.doneReason,
					PromptEvalCount:    seq.numPromptInputs,
					PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
					EvalCount:          seq.numPredicted,
					EvalDuration:       time.Since(seq.startGenerationTime),
				}); err != nil {
					http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
				}

				return
			}
		}
	}
}

// health handles GET /health by reporting the runner's load status and
// loading progress as a JSON llm.ServerStatusResponse.
//
// NOTE(review): s.status and s.progress are also written by loadModel on its
// own goroutine, and they are read here without synchronization.
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	report := llm.ServerStatusResponse{
		Status:   s.status,
		Progress: s.progress,
	}
	encodeErr := json.NewEncoder(w).Encode(&report)
	if encodeErr == nil {
		return
	}
	http.Error(w, fmt.Sprintf("failed to encode response: %v", encodeErr), http.StatusInternalServerError)
}

// reserveWorstCaseGraph builds the forward graph for a full batch of
// s.batchSize inputs with s.parallel outputs. For multimodal models, the
// batch also includes a large image. The graph is passed to Reserve so the
// backend can size its compute buffers up front; it is never computed.
func (s *Server) reserveWorstCaseGraph() error {
	ctx := s.model.Backend().NewContext()
	defer ctx.Close()

	var err error
	inputs := make([]*input.Input, s.batchSize)
	for i := range inputs {
		inputs[i] = &input.Input{}
	}
	mmStore := newMultimodalStore()

	// Multimodal strategy:
	// - Encode a 2048x2048 image. This assumes that a single image of this
	//   size is sufficient to trigger the worst case. This is currently true
	//   because for existing models, only a single image fits in a batch.
	// - Add the embedding to a full batch of tokens - this is necessary because
	//   the model may be looking for non-image data, such as <image> tags.
	// - Run PostTokenize to execute any transformations between generated
	//   embeddings and what the forward pass expects.
	// - The result may now be larger than a batch (images may not fit in a
	//   single batch), so trim based on what will fit and must be grouped together.
	// - Fill out the rest of the space with text tokens.
	if multimodalProcessor, ok := s.model.(model.MultimodalProcessor); ok {
		mmCtx := s.model.Backend().NewContext()
		defer mmCtx.Close()

		img := image.NewGray(image.Rect(0, 0, 2048, 2048))
		var buf bytes.Buffer
		// NOTE(review): the bmp.Encode error is ignored.
		bmp.Encode(&buf, img)

		// NOTE(review): if EncodeMultimodal fails, its error is not returned.
		// The reservation silently falls back to a text-only batch, and err is
		// overwritten by the Forward call below.
		if inputs[0].Multimodal, err = multimodalProcessor.EncodeMultimodal(mmCtx, buf.Bytes()); err == nil {
			mmStore.addMultimodal(inputs[0].Multimodal)

			inputs, err = multimodalProcessor.PostTokenize(inputs)
			if err != nil {
				return err
			}

			// Keep a group that must share a batch even if it exceeds
			// batchSize. Otherwise cut before the first group that would spill
			// past batchSize.
			for i, inp := range inputs {
				minBatch := 1 + inp.SameBatch
				if minBatch > s.batchSize {
					inputs = inputs[i:min(i+minBatch, len(inputs))]
					break
				} else if i+minBatch > s.batchSize {
					inputs = inputs[:i]
					break
				}
			}

			// Pad with empty text inputs back up to a full batch.
			if len(inputs) < s.batchSize {
				newInputs := make([]*input.Input, s.batchSize)
				copy(newInputs, inputs)
				for i := len(inputs); i < s.batchSize; i++ {
					newInputs[i] = &input.Input{}
				}
				inputs = newInputs
			}
		}
	}

	var batch input.Batch

	batchInputs := make([]int32, len(inputs))
	batch.Positions = make([]int32, len(inputs))
	batch.Sequences = make([]int, len(inputs))
	for i, inp := range inputs {
		batchInputs[i] = inp.Token
		if inp.Multimodal != nil {
			mm, err := mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, true)
			if err != nil {
				return err
			}
			batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: i, Multimodal: mm})
		}

		batch.Positions[i] = int32(i)
	}

	// One output row per parallel sequence.
	batch.Outputs = make([]int32, s.parallel)
	for i := range batch.Outputs {
		batch.Outputs[i] = int32(i)
	}

	batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))

	cache := s.model.Config().Cache
	if cache != nil {
		err := cache.StartForward(ctx, batch, true)
		if err != nil {
			return err
		}
	}

	t, err := s.model.Forward(ctx, batch)
	if err != nil {
		return err
	}

	ctx.Forward(t).Reserve()

	return nil
}

// allocModel pre-allocates the maximum needed memory for a model
// based on the given parameters.
//
// It creates the model and its input cache, sizes s.seqs and s.seqsSem for
// the requested parallelism, and reserves the worst-case compute graph.
// Parallelism is reduced to 1 when the cache is not enabled. LoRA adapters
// are not supported yet. A panic whose value is an error (how memory
// allocation failures surface) is returned as panicErr. Any other panic is
// re-raised.
func (s *Server) allocModel(
	mpath string,
	params ml.BackendParams,
	loraPath []string,
	parallel int,
	kvCacheType string,
	kvSize int,
	multiUserCache bool,
) (panicErr error) {
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		debug.PrintStack()
		asErr, isErr := recovered.(error)
		if !isErr {
			panic(recovered)
		}
		panicErr = asErr
	}()

	var err error
	if s.model, err = model.New(mpath, params); err != nil {
		return err
	}

	// TODO(jessegross): LoRA loading
	if len(loraPath) != 0 {
		return errors.New("loras are not yet implemented")
	}

	if s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache); err != nil {
		return err
	}

	if parallel > 1 && !s.cache.enabled {
		slog.Warn("model does not support caching, disabling parallel processing")
		parallel = 1
	}

	s.parallel = parallel
	s.seqs = make([]*Sequence, parallel)
	s.seqsSem = semaphore.NewWeighted(int64(parallel))

	return s.reserveWorstCaseGraph()
}

// closeModel frees all memory associated with a model.
//
// It is safe to call when nothing is loaded. load calls it before the very
// first allocation and after allocModel fails part-way, and in both cases
// s.cache and/or s.model may be nil.
func (s *Server) closeModel() {
	if s.cache != nil {
		s.cache.Close()
		s.cache = nil
	}
	if s.model != nil {
		s.model.Backend().Close()
		s.model = nil
	}
}

// loadModel loads the weights for a model. The memory must already
// have been allocated with allocModel.
//
// Loading progress is published in s.progress. On success the status becomes
// ready and s.ready is released. Any load error panics, because this runs on
// its own goroutine with nobody to return the error to.
func (s *Server) loadModel() {
	reportProgress := func(p float32) {
		s.progress = p
	}

	if loadErr := s.model.Backend().Load(context.TODO(), reportProgress); loadErr != nil {
		panic(fmt.Errorf("failed to load model: %v", loadErr))
	}

	s.status = llm.ServerStatusReady
	s.ready.Done()
}

// load is the handler called by the Ollama server to process different
// load operations.
//
// Requests are rejected once the server has left the "launched" state.
// LoadOperationClose frees everything. Any other operation (re)allocates the
// model unless the previous request had identical parameters (the operation
// itself is ignored in that comparison). After allocation:
//   - LoadOperationFit releases the allocation again.
//   - LoadOperationAlloc keeps it for future operations.
//   - LoadOperationCommit starts loading the weights in the background.
//
// The response reports the backend memory usage. If allocation ran out of
// memory, it has Success=false together with the memory that was needed.
func (s *Server) load(w http.ResponseWriter, r *http.Request) {
	s.loadMu.Lock()
	defer s.loadMu.Unlock()

	w.Header().Set("Content-Type", "application/json")

	if s.status != llm.ServerStatusLaunched {
		http.Error(w, "model already loaded", http.StatusInternalServerError)
		return
	}

	var req llm.LoadRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "bad request", http.StatusBadRequest)
		return
	}

	slog.Info("load", "request", req)

	writeResponse := func(resp llm.LoadResponse) {
		if err := json.NewEncoder(w).Encode(&resp); err != nil {
			http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
		}
	}

	if req.Operation == llm.LoadOperationClose {
		s.closeModel()
		writeResponse(llm.LoadResponse{})
		return
	}

	// Only the load parameters decide whether the existing allocation can be
	// reused, so copy the operation over before comparing.
	s.lastLoad.Operation = req.Operation
	needsAlloc := s.model == nil || !reflect.DeepEqual(req, s.lastLoad)
	s.lastLoad = req

	if needsAlloc {
		s.closeModel()

		s.batchSize = req.BatchSize
		params := ml.BackendParams{
			AllocMemory:    req.Operation != llm.LoadOperationFit,
			NumThreads:     req.NumThreads,
			GPULayers:      req.GPULayers,
			FlashAttention: req.FlashAttention,
		}

		if err := s.allocModel(s.modelPath, params, req.LoraPath, req.Parallel, req.KvCacheType, req.KvSize, req.MultiUserCache); err != nil {
			s.closeModel()

			var noMem ml.ErrNoMem
			if !errors.As(err, &noMem) {
				http.Error(w, fmt.Sprintf("failed to initialize model: %v", err), http.StatusInternalServerError)
				return
			}

			writeResponse(llm.LoadResponse{Success: false, Memory: noMem.BackendMemory})
			return
		}
	}

	mem := s.model.Backend().BackendMemory()

	switch req.Operation {
	case llm.LoadOperationFit:
		// A fit allocation is only a measurement; release it right away.
		s.closeModel()
	case llm.LoadOperationCommit:
		s.status = llm.ServerStatusLoadingModel
		go s.loadModel()
	}
	// LoadOperationAlloc matches neither case and stays open for future operations.

	writeResponse(llm.LoadResponse{Success: true, Memory: mem})
}

// Execute is the runner's entry point. It parses the command-line flags,
// starts the batch scheduling loop, and serves the runner HTTP API
// (/load, /completion, /health, plus a stub /embedding) on
// 127.0.0.1:<port>. It returns only when listening or serving fails.
func Execute(args []string) error {
	fs := flag.NewFlagSet("runner", flag.ExitOnError)
	mpath := fs.String("model", "", "Path to model binary file")
	port := fs.Int("port", 8080, "Port to expose the server on")
	_ = fs.Bool("verbose", false, "verbose output (default: disabled)")

	fs.Usage = func() {
		fmt.Fprintf(fs.Output(), "Runner usage\n")
		fs.PrintDefaults()
	}
	if err := fs.Parse(args); err != nil {
		return err
	}
	slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
	slog.Info("starting ollama engine")

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	server := &Server{
		modelPath: *mpath,
		status:    llm.ServerStatusLaunched,
		hardErrCh: make(chan error, 1),
	}

	server.cond = sync.NewCond(&server.mu)
	server.ready.Add(1)

	go server.run(ctx)

	addr := "127.0.0.1:" + strconv.Itoa(*port)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		fmt.Println("Listen error:", err)
		return err
	}
	defer listener.Close()

	mux := http.NewServeMux()
	// TODO: support embeddings
	mux.HandleFunc("POST /load", server.load)
	mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) {
		http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
	})

	mux.HandleFunc("POST /completion", server.completion)
	mux.HandleFunc("GET /health", server.health)

	httpServer := http.Server{
		Handler: mux,
	}

	log.Println("Server listening on", addr)
	// Return the failure instead of calling log.Fatal. log.Fatal exits
	// immediately, skipping the deferred listener.Close and cancel, and made
	// the following return unreachable.
	if err := httpServer.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
		return fmt.Errorf("server error: %w", err)
	}

	return nil
}
